# Author: @悠悠智汇笔记
# Created: 2025-07-25

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def create_marginal_boxplot(data_path, main_x='displ', main_y='hwy',
                            size_var='cty', color_var='manufacturer',
                            main_title='Scatterplot with Histograms',
                            boxplot_colors=('#1f77b4', '#ff7f0e'),
                            figsize=(12, 8), dpi=80):
    """
    创建带有边际箱线图的散点图

    参数:
    data_path (str): 数据文件路径
    main_x (str): 主图x轴变量名
    main_y (str): 主图y轴变量名
    size_var (str): 散点大小变量名
    color_var (str): 散点颜色变量名
    main_title (str): 主图标题
    boxplot_colors (tuple): (右侧箱线图颜色, 底部箱线图颜色)
    figsize (tuple): 图形尺寸
    dpi (int): 图形分辨率
    """
    # 导入数据
    df = pd.read_csv(data_path)

    # 创建图形和网格布局
    fig = plt.figure(figsize=figsize, dpi=dpi)
    grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)

    # 定义坐标轴
    ax_main = fig.add_subplot(grid[:-1, :-1])  # 主图
    ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])  # 右侧箱线图
    ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])  # 底部箱线图

    # 主图上的散点图
    ax_main.scatter(main_x, main_y,
                    s=df[size_var] * 5,
                    c=df[color_var].astype('category').cat.codes,
                    alpha=.8, data=df, cmap="Set1",
                    edgecolors='black', linewidths=.5)

    # 添加带颜色的箱线图
    sns.boxplot(df[main_y], ax=ax_right, orient="v", color=boxplot_colors[0])
    sns.boxplot(df[main_x], ax=ax_bottom, orient="h", color=boxplot_colors[1])

    # 图形修饰
    ax_bottom.set(xlabel='')
    ax_right.set(ylabel='')

    # 设置标题和标签
    ax_main.set(title=f'{main_title}\n{main_x} vs {main_y}',
                xlabel=main_x, ylabel=main_y)

    # 设置字体大小
    ax_main.title.set_fontsize(20)
    for item in ([ax_main.xaxis.label, ax_main.yaxis.label] +
                 ax_main.get_xticklabels() + ax_main.get_yticklabels()):
        item.set_fontsize(14)

    return fig


# 使用示例
fig = create_marginal_boxplot(
    data_path="mpg_ggplot2.csv",
    boxplot_colors=('#2ca02c', '#d62728'),  # 修改箱线图颜色为绿色和红色
    main_title='Engine Displacement vs Highway Mileage'
)

# 保存和显示
plt.savefig('marginal_boxplot_custom.png', dpi=300, bbox_inches='tight')
plt.show()